- Notifications
You must be signed in to change notification settings - Fork 28
/
Copy pathVGG19_Training_Module_Final.py
127 lines (101 loc) · 4.61 KB
/
VGG19_Training_Module_Final.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
fromsklearn.model_selectionimporttrain_test_split
fromkeras.utils.np_utilsimportto_categorical
frompylabimport*
fromkeras.callbacksimportLearningRateScheduler
fromkerasimportmodels
fromkerasimportlayers
importnumpyasnp
importmatplotlib.pyplotasplt
fromkeras.applications.vgg19importVGG19
importitertools
fromsklearn.metricsimportconfusion_matrix
fromkerasimportoptimizers
fromPILimportImage
importpandasaspd
importos
importrandom
# Side length (pixels) every input image is resized to, and number of classes.
_IMAGE_SIZE = (100, 100)
_NUM_CLASSES = 2


def _step_decay_schedule(initial_lr=1e-3, decay_factor=0.75, step_size=10):
    """Return a Keras callback implementing a step-wise learning-rate decay.

    The rate for a given epoch is
    ``initial_lr * decay_factor ** (epoch // step_size)``.
    """
    def schedule(epoch):
        return initial_lr * decay_factor ** (epoch // step_size)

    return LearningRateScheduler(schedule)


def _read_image(path, size=_IMAGE_SIZE):
    """Load *path* as RGB, resize it to *size*, and scale pixels to [0, 1]."""
    # The context manager releases the file handle as soon as the pixel data
    # has been copied by convert(); the original left every file open.
    with Image.open(path) as img:
        rgb = img.convert('RGB').resize(size)
    return np.asarray(rgb, dtype=np.float64) / 255.0


def _load_dataset(csv_file):
    """Read the CSV and return ``(X, Y)`` arrays ready for training.

    The first CSV column holds image paths and the second holds integer
    class labels; labels are one-hot encoded into ``_NUM_CLASSES`` columns.
    """
    dataset = pd.read_csv(csv_file)
    if dataset.empty:
        raise ValueError("No samples found in %r" % (csv_file,))

    # Explicit positional column access; `row[0]` on a label-indexed Series
    # depends on a deprecated pandas fallback.
    paths = dataset.iloc[:, 0]
    labels = dataset.iloc[:, 1].values

    X = np.array([_read_image(path) for path in paths])
    X = X.reshape(-1, _IMAGE_SIZE[0], _IMAGE_SIZE[1], 3)
    Y = to_categorical(labels, _NUM_CLASSES)
    return X, Y


def _build_model():
    """Assemble VGG19 (ImageNet weights, no top) plus a dense classifier head."""
    vgg_conv = VGG19(weights='imagenet', include_top=False,
                     input_shape=_IMAGE_SIZE + (3,))

    # Freeze everything except the last five layers of the convolutional base.
    for layer in vgg_conv.layers[:-5]:
        layer.trainable = False

    # Report the trainable status of each individual layer.
    for layer in vgg_conv.layers:
        print(layer, layer.trainable)

    model = models.Sequential()
    model.add(vgg_conv)
    model.add(layers.Flatten())
    model.add(layers.Dense(1024, activation='relu'))
    model.add(layers.Dropout(0.50))
    model.add(layers.Dense(1024, activation='relu'))
    model.add(layers.Dropout(0.50))
    model.add(layers.Dense(_NUM_CLASSES, activation='softmax'))

    # Summarise once the model is complete so the dense head is included.
    model.summary()
    return model


def _plot_confusion_matrix(ax, cm, classes, normalize=False,
                           title='Confusion matrix', cmap=None):
    """Draw confusion matrix *cm* on *ax* with one text label per cell."""
    if cmap is None:
        cmap = plt.cm.Blues

    # Normalise before drawing so the image and the cell text agree.
    if normalize:
        cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]

    image = ax.imshow(cm, interpolation='nearest', cmap=cmap)
    ax.set_title(title)
    ax.figure.colorbar(image, ax=ax)

    tick_marks = np.arange(len(classes))
    tick_labels = [str(c) for c in classes]
    ax.set_xticks(tick_marks)
    ax.set_xticklabels(tick_labels, rotation=45)
    ax.set_yticks(tick_marks)
    ax.set_yticklabels(tick_labels)

    # Cells darker than half the maximum get white text for contrast.
    thresh = cm.max() / 2.
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        ax.text(j, i, cm[i, j], horizontalalignment="center",
                color="white" if cm[i, j] > thresh else "black")

    ax.set_ylabel('True label')
    ax.set_xlabel('Predicted label')


def train_VGG19_Model(csv_file, lr, ep):
    """Fine-tune VGG19 on the images listed in *csv_file* and save the results.

    Args:
        csv_file: Path to a CSV file whose first column is an image path and
            whose second column is an integer class label (0 or 1).
        lr: Initial learning rate. It is given to the Adagrad optimizer and
            is the starting point of the step-decay schedule (x0.75 every
            2 epochs).
        ep: Number of training epochs.

    Returns:
        A ``(plot_path, model_path)`` tuple: the PNG containing the loss and
        accuracy curves plus the validation confusion matrix, and the saved
        ``.h5`` model file.

    Raises:
        ValueError: If the CSV file contains no rows.
    """
    X, Y = _load_dataset(csv_file)
    X_train, X_val, Y_train, Y_val = train_test_split(
        X, Y, test_size=0.20, random_state=5)

    model = _build_model()

    optimizer = optimizers.Adagrad(lr=lr, epsilon=None, decay=0.0)
    # NOTE(review): categorical cross-entropy is the usual loss for a softmax
    # classifier; mean squared error is kept to preserve existing behaviour.
    model.compile(optimizer=optimizer,
                  loss="mean_squared_error",
                  metrics=["accuracy"])

    # The scheduler overwrites the optimizer's rate at the start of every
    # epoch, so it must be seeded with the caller's `lr`; a hard-coded value
    # here would make the `lr` argument have no effect.
    lr_sched = _step_decay_schedule(initial_lr=lr, decay_factor=0.75,
                                    step_size=2)
    history = model.fit(X_train, Y_train, batch_size=20, epochs=ep,
                        validation_data=(X_val, Y_val), verbose=2,
                        callbacks=[lr_sched])

    # Plot the loss and accuracy curves for training and validation.
    fig, ax = plt.subplots(3, 1)
    ax[0].plot(history.history['loss'], color='b', label="Training loss")
    ax[0].plot(history.history['val_loss'], color='r', label="validation loss")
    ax[0].legend(loc='best', shadow=True)

    # Older Keras releases record the metric as 'acc', newer ones as 'accuracy'.
    acc_key = 'acc' if 'acc' in history.history else 'accuracy'
    ax[1].plot(history.history[acc_key], color='b', label="Training accuracy")
    ax[1].plot(history.history['val_' + acc_key], color='r',
               label="Validation accuracy")
    ax[1].legend(loc='best', shadow=True)

    # Predict on the validation set and turn one-hot rows into class indices.
    Y_pred_classes = np.argmax(model.predict(X_val), axis=1)
    Y_true = np.argmax(Y_val, axis=1)

    # Fixing `labels` keeps the matrix 2x2 even if the validation split
    # happens to contain a single class.
    class_ids = list(range(_NUM_CLASSES))
    confusion_mtx = confusion_matrix(Y_true, Y_pred_classes, labels=class_ids)
    _plot_confusion_matrix(ax[2], confusion_mtx, classes=class_ids)
    fig.tight_layout()

    # Build the output locations and make sure the directories exist.
    image_path = os.path.join(os.getcwd(), "Figures")
    Models_path = os.path.join(os.getcwd(), "Re_Traind_Models")
    os.makedirs(image_path, exist_ok=True)
    os.makedirs(Models_path, exist_ok=True)

    file_number = random.randint(1, 1000000)
    plot_Name = os.path.join(image_path, "VGG19_" + str(file_number) + ".png")
    Model_Name = os.path.join(Models_path, "VGG19_" + str(file_number) + ".h5")

    fig.savefig(plot_Name, transparent=True, bbox_inches="tight",
                pad_inches=2, dpi=50)
    # Release the figure so repeated training runs do not accumulate figures.
    plt.close(fig)

    model.save(Model_Name)
    return plot_Name, Model_Name